import abc
from enum import Enum
import sys
from functools import lru_cache
from typing import Any, List, Optional, Type

import country_converter as coco
from chainlite import chain
from pydantic import BaseModel, Field, create_model

sys.path.insert(0, "./")
from log_utils import get_logger

logger = get_logger(__name__)

# Single shared converter instance; normalize_country_name() reuses it for every lookup.
cc = coco.CountryConverter()


@lru_cache
def normalize_country_name(country_name: str) -> Optional[str]:
    """Map a free-form country name to country_converter's ``name_short`` form.

    Results are memoized with ``functools.lru_cache`` (default size), so the
    argument must be hashable (a plain string).

    Args:
        country_name: Country name as it appears in the text / LLM output.

    Returns:
        The short country name reported by ``country_converter``, or ``None``
        when the name is not recognized (e.g. a region such as "Europe").
        If the converter returns several candidates, a warning is logged and
        the first recognized candidate is returned (``None`` if none is).
    """
    ret = cc.convert(names=country_name, to="name_short", not_found="not found")
    if isinstance(ret, list):
        logger.warning(
            "When normalizing %s, output was a list: %s", country_name, str(ret)
        )
        # Take the first candidate the converter actually recognized; an empty
        # list (or one holding only the sentinel) maps to None instead of
        # raising IndexError.
        ret = next((candidate for candidate in ret if candidate != "not found"), None)
    if ret == "not found":
        # e.g. sometimes LLM output Europe as a country name
        ret = None
    return ret


def _first_stripped(value: Any) -> Any:
    """Collapse a list to its first element and strip surrounding whitespace.

    LLM output sometimes contains a list where a single string is expected;
    only the first entry is kept, and an empty list becomes ``None``. Strings
    are stripped; any other value is returned unchanged.
    """
    if isinstance(value, list):
        value = value[0] if value else None
    if isinstance(value, str):
        value = value.strip()
    return value


class Location(BaseModel):
    """
    Represent a location mention in the text.
    country, state_or_province and city can be None if they are not identifiable from the given text.
    """

    mention: str = Field(
        ..., description="The string from the text that refers to this location"
    )
    country: Optional[str] = Field(
        ...,
        description="Normalized name of a country, e.g. United States. If there are multiple countries, they should be put into separate Location objects",
    )
    state_or_province: Optional[str] = Field(
        ..., description="Normalized name of a state or province, e.g. California"
    )
    city: Optional[str] = Field(..., description="Name of a city, e.g. Tehran")
    address: Optional[str] = Field(..., description="Street name, building name etc.")

    def __init__(
        self,
        mention: str,
        country: Optional[str] = None,
        state_or_province: Optional[str] = None,
        city: Optional[str] = None,
        address: Optional[str] = None,
    ):
        """Clean up the location parts, then run normal pydantic validation.

        ``country``, ``state_or_province`` and ``city`` may arrive as lists
        (only the first entry is kept) and are whitespace-stripped. A non-empty
        country is additionally normalized via ``normalize_country_name``,
        which yields ``None`` for unrecognized names. ``address`` is passed
        through unchanged.
        """
        country = _first_stripped(country)
        if country:
            country = normalize_country_name(country)
        # Delegate to BaseModel.__init__ so pydantic initializes its internal
        # state; assigning attributes before that raises AttributeError.
        super().__init__(
            mention=mention,
            country=country,
            state_or_province=_first_stripped(state_or_province),
            city=_first_stripped(city),
            address=address,
        )


class AbstractEvent(BaseModel, abc.ABC):
    """A general class to represent events."""

    def __init__(self, **data: Any):
        """Validate ``data``, then sort every list-valued field in place.

        Sorting gives list fields a canonical order, so events listing the
        same items in different orders compare and hash equal. Lists whose
        elements cannot be ordered with ``<`` (e.g. lists of nested models)
        are ordered by ``repr`` instead of raising ``TypeError``.
        """
        super().__init__(**data)
        for field_value in self.__dict__.values():
            if isinstance(field_value, list):
                try:
                    field_value.sort()
                except TypeError:
                    # Elements are not mutually comparable; fall back to a
                    # deterministic textual order.
                    field_value.sort(key=repr)

    def to_string_ea(self, include_arguments=True, include_summary=False):
        """Render the event as ``ClassName(field=repr(value), ...)``.

        Args:
            include_arguments: If False, no fields are rendered (``ClassName()``).
            include_summary: If False, the ``summary`` field is omitted.
        """
        fields_to_include = list(self.__dict__.keys()) if include_arguments else []
        if not include_summary:
            fields_to_include = [
                field for field in fields_to_include if field != "summary"
            ]
        fields_str = ", ".join(
            f"{field_name}={repr(self.__dict__[field_name])}"
            for field_name in fields_to_include
        )
        return f"{self.__class__.__name__}({fields_str})"

    def get_event_type(self) -> str:
        """Return the concrete event class name."""
        return self.__class__.__name__

    def _comparison_key(self) -> tuple:
        """Return ``(field_name, value)`` pairs with list values turned into tuples.

        Shared by ``__hash__`` and ``__eq__`` so both agree on what is compared.
        """
        items = []
        # Read model_fields from the class: instance access is deprecated in pydantic 2.11.
        for field_name in type(self).model_fields:
            value = getattr(self, field_name)
            items.append(
                (field_name, tuple(value) if isinstance(value, list) else value)
            )
        return tuple(items)

    def __hash__(self):
        # Include the concrete type so different event kinds with identical
        # field values do not collide.
        return hash((type(self).__name__, self._comparison_key()))

    def __eq__(self, other):
        if not isinstance(other, AbstractEvent):
            return NotImplemented
        # Events of different types are never equal, and may not even share
        # field names (reading them off `other` would raise AttributeError).
        if type(self) is not type(other):
            return False
        return self._comparison_key() == other._comparison_key()